# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pipelines to generate datasets for the alignment and homology tasks."""

import collections
import functools
import hashlib
import itertools
import random
from typing import Callable, Dict, Iterable, Iterator, Optional

import apache_beam as beam
import numpy as np
import tensorflow as tf

from dedal.preprocessing import alignment
from dedal.preprocessing import schemas
from dedal.preprocessing import schemas_lib
from dedal.preprocessing import types
from dedal.preprocessing import utils


# Type aliases
Record = types.Record  # Record type declared in `dedal.preprocessing.types`.
Array = np.ndarray
PRNG = random.Random


# Constants
# Per-region fields requested from the parsed Pfam tables by the pipeline
# builders in this module (passed to `ReadParsedPfamData` as `fields_to_keep`).
ALIGNMENT_FIELDS = (
    'key', 'pfam_acc', 'clan_acc', 'seq_start', 'seq_end', 'passed_qc',
    'sequence', 'gapped_sequence',
    'hmm_hit_seq_start', 'hmm_hit_seq_end', 'hmm_hit_clan_acc',
    'other_hit_seq_start', 'other_hit_seq_end', 'other_hit_type_id')
# Annotation type ids for which `eval_confounding_in_regions` emits a
# `shares_<type_id>` flag.
OTHER_REGIONS = (
    'coiled_coil', 'disorder', 'low_complexity', 'sig_p', 'transmembrane')
# Annotation categories whose `shares_<category>` flag marks a non-homologous
# region pair as "maybe confounded" in `eval_confounding_in_regions`.
CONFOUNDING_REGIONS = ('clans',)
# Dataset split names. NOTE(review): presumably the valid values of
# `target_split`; not referenced elsewhere in this module -- confirm.
SPLITS = ('train', 'iid_validation', 'ood_validation', 'iid_test', 'ood_test')
PREFIXES = ('n', 'c')  # Indices of the two flanks (N-terminus, C-terminus).
SUFFIXES = ('x', 'y')  # Indices of the two regions in the pairs.
# Sampling weight of each amino acid, used by `sample_synthetic_flank` to draw
# synthetic flank residues. NOTE(review): these look like background amino acid
# frequencies; their provenance is not documented here.
AA_PROB_TABLE = {'A': 0.0832177442683677,
                 'C': 0.013846953304580488,
                 'D': 0.05746458363960445,
                 'E': 0.0662881179936497,
                 'F': 0.03796971998594428,
                 'G': 0.06893382361885941,
                 'H': 0.021288200487753935,
                 'I': 0.05563140763547133,
                 'K': 0.05514607272883951,
                 'L': 0.09466292698151958,
                 'M': 0.021795269979457313,
                 'N': 0.042539379191607996,
                 'P': 0.04820225072073993,
                 'Q': 0.0397779223030366,
                 'R': 0.05446797313841446,
                 'S': 0.07095409709482754,
                 'T': 0.05849503653768657,
                 'V': 0.06729418111891682,
                 'W': 0.011878265406494127,
                 'Y': 0.030146073864228278}


def load_pfam_accs(path):
  """Builds a lookup table from Pfam accessions to integer indices.

  Args:
    path: The path to a plain text file listing one Pfam accession per line.

  Returns:
    A dictionary keyed by the (whitespace-stripped) Pfam accessions found in
    the file. Each value is the zero-based position of the line on which the
    accession appears.

  Raises:
    ValueError: If the same accession appears on more than one line.
  """
  index_by_acc = {}
  with tf.io.gfile.GFile(path, 'r') as handle:
    stripped_lines = (raw_line.strip() for raw_line in handle)
    for position, acc in enumerate(stripped_lines):
      if acc in index_by_acc:
        raise ValueError(f'Key {acc} is duplicated in {path}.')
      index_by_acc[acc] = position
  return index_by_acc


class ReadParsedPfamData(beam.PTransform):
  """Reads TSV data generated by Step 1. of the preprocessing pipeline.

  The transform reads the parsed Pfam region tables, optionally filters the
  regions by quality control flag and by length, joins in the dataset split of
  each region (by `key`) and finally drops the join key, so that the output
  `PCollection` contains only the merged region records.
  """

  def __init__(
      self,
      file_pattern,
      dataset_splits_path,
      fields_to_keep = None,
      with_flank_seeds = False,
      max_len = None,
      filter_by_qc = True,
  ):
    """Initializes the transform.

    Args:
      file_pattern: The file pattern of the parsed Pfam region tables.
      dataset_splits_path: The path to the table mapping region keys to dataset
        splits.
      fields_to_keep: The fields of the region table to read, forwarded to
        `schemas_lib.ReadFromTable`.
      with_flank_seeds: Whether the region tables follow the
        `schemas.ExtendedParsedPfamRow` schema (if `True`) or the
        `schemas.ParsedPfamRow` schema (if `False`).
      max_len: If not `None`, regions spanning more than `max_len` residues
        (`seq_end - seq_start + 1`) are discarded.
      filter_by_qc: Whether to discard regions whose `passed_qc` field is falsy.
    """
    # Lets the base `PTransform` set up its own state (e.g. the label).
    super().__init__()
    self.file_pattern = file_pattern
    self.dataset_splits_path = dataset_splits_path
    self.fields_to_keep = fields_to_keep
    self.with_flank_seeds = with_flank_seeds
    self.max_len = max_len
    self.filter_by_qc = filter_by_qc
    self._schema_cls = (schemas.ExtendedParsedPfamRow if self.with_flank_seeds
                        else schemas.ParsedPfamRow)

  def expand(
      self,
      root,
  ):
    """Builds the read, filter and join sub-graph rooted at `root`."""
    read_alignment_data_cls = functools.partial(
        schemas_lib.ReadFromTable,
        schema_cls=self._schema_cls,
        key_field='key',
        skip_header_lines=1,
        fields_to_keep=self.fields_to_keep)

    read_dataset_splits_cls = functools.partial(
        schemas_lib.ReadFromTable,
        schema_cls=schemas.DatasetSplits,
        key_field='key',
        skip_header_lines=1,
        fields_to_keep=('split',))

    # Reads Pfam data from the preceding pipeline. Each element in the output
    # `PCollection` represents a different Pfam region, as a (key, record)
    # pair -- hence the `x[1]` indexing in the filters below.
    regions = root | 'ReadRegions' >> read_alignment_data_cls(self.file_pattern)

    # Optionally, removes any Pfam regions that did not pass the quality control
    # checks from the `PCollection`.
    if self.filter_by_qc:
      regions = regions | 'QCFilter' >> beam.Filter(lambda x: x[1]['passed_qc'])

    # Optionally, drops Pfam regions that do not pass the maximum region length
    # filter.
    if self.max_len is not None:
      # Bound to a local so that the filter closure captures only the integer
      # threshold instead of the whole transform instance.
      max_len = self.max_len
      len_fn = lambda x: x[1]['seq_end'] - x[1]['seq_start'] + 1 <= max_len
      regions = regions | 'LengthFilter' >> beam.Filter(len_fn)

    # Reads mapping Pfam region key: split.
    dataset_splits = (
        root
        | 'ReadDatasetSplits' >> read_dataset_splits_cls(
            self.dataset_splits_path))
    # Merges split info from `dataset_splits` into elements `regions` and
    # removes the key of each element, which is no longer needed after merging
    # the two `PCollection`s.
    return (
        {'pfam_regions': regions, 'dataset_splits': dataset_splits}
        | 'MergeSplitInfoByKey' >> schemas_lib.JoinTables(
            left_join_tables=['pfam_regions'])
        | 'RemoveKey' >> beam.Values())


def get_prng(record, global_seed, field_name = 'key'):
  """Generates a per-example pair reproducible PRNG.

  The seed is derived from a SHA-256 digest of the pair's field values rather
  than from the built-in `hash`: `hash` of `str` values is salted per
  interpreter process (unless `PYTHONHASHSEED` is fixed), so it would yield
  different seeds for the same pair on different workers or runs.

  Args:
    record: A region pair. Must contain the field `f'{field_name}_{s}'` for
      each `s` in `SUFFIXES`.
    global_seed: An integer offset added to the per-pair seed.
    field_name: The (unsuffixed) name of the field identifying each region.

  Returns:
    A `random.Random` instance whose seed depends only on the values of the
    selected fields (via their `repr`, which is stable for `str` and `int`
    values) and on `global_seed`.
  """
  values = tuple(record[f'{field_name}_{s}'] for s in SUFFIXES)
  digest = hashlib.sha256(repr(values).encode('utf-8')).digest()
  # The leading 64 bits of the digest are used as the per-pair seed.
  return random.Random(int.from_bytes(digest[:8], 'big') + global_seed)


def sample_flank_lengths(
    rng,
    seq_start,
    seq_end,
    max_len,
):
  """Draws random lengths for the N-terminus and C-terminus flanks.

  First, the total flank length is drawn from the integers between zero and
  `max_len` minus the region length (both inclusive). Then, that total is
  split at random between the two flanks.

  Args:
    rng: The PRNG to draw from. Must provide `randint(a, b)`.
    seq_start: The (inclusive) start position of the region.
    seq_end: The (inclusive) end position of the region.
    max_len: The maximum length of the region plus its two flanks.

  Returns:
    A dictionary with the sampled lengths under the keys `n_len` and `c_len`.
  """
  # Residues left for flanks once the region itself (inclusive endpoints) has
  # been accounted for. Either flank alone may take up the whole budget.
  budget = max_len - (seq_end - seq_start + 1)
  total = rng.randint(0, budget)
  smallest_n = max(total - budget, 0)
  largest_n = min(budget, total)
  n_len = rng.randint(smallest_n, largest_n)
  return {'n_len': n_len, 'c_len': total - n_len}


def validate_flank_seeds(
    region_pair,
    indices,
    extra_margin = 0,
    min_overlap = 1,
):
  """Tests if a combination of UniProt flanks is valid for a region pair.

  Args:
    region_pair: The region pair, including its sampled flank lengths and its
      precomputed flank seeds.
    indices: A mapping from `f'{prefix}_{suffix}'` to the index of the flank
      seed being considered for that flank.
    extra_margin: The number of residues by which the region endpoints are
      extended when looking for overlapping hmm hits.
    min_overlap: The minimum overlap (in residues) for an hmm hit to count as
      overlapping with a region.

  Returns:
    `False` if a needed flank seed is empty, if the flanks of the two sequences
    share a clan annotation, or if the flanks of one sequence share a clan
    annotation with the main region of the other sequence. `True` otherwise.
  """
  flank_slots = list(itertools.product(PREFIXES, SUFFIXES))

  # Check 1: every flank with non-zero length needs a non-empty seed.
  for prefix, suffix in flank_slots:
    needed_len = region_pair[f'{prefix}_len_{suffix}']
    seed_idx = indices[f'{prefix}_{suffix}']
    seed_key = region_pair[f'{prefix}_flank_seed_key_{seed_idx}_{suffix}']
    seed_seq = region_pair[f'{prefix}_flank_seed_sequence_{seed_idx}_{suffix}']
    if needed_len and not (seed_key and seed_seq):
      return False

  # Check 2: pools the clan annotations of the candidate flank seeds of each
  # sequence; the two pools must not have any element in common.
  ann_type = 'hmm_hit_clan_acc'
  clans_in_flanks = {suffix: set() for suffix in SUFFIXES}
  for prefix, suffix in flank_slots:
    seed_idx = indices[f'{prefix}_{suffix}']
    clans_in_flanks[suffix].update(
        region_pair[f'{prefix}_flank_seed_{ann_type}_{seed_idx}_{suffix}'])
  if set.intersection(*clans_in_flanks.values()):
    return False

  # Check 3: the clans overlapping with the main region of one sequence must
  # not show up in the flanks of the other sequence.
  for own, other in zip(SUFFIXES, SUFFIXES[::-1]):
    # TODO(fllinares): pre-compute `hmm_hits`, perhaps at the cost of clarity.
    overlap_lens = interval_overlaps(
        start=region_pair[f'seq_start_{own}'] - extra_margin,
        end=region_pair[f'seq_end_{own}'] + extra_margin,
        ref_starts=np.asarray(region_pair[f'hmm_hit_seq_start_{own}']),
        ref_ends=np.asarray(region_pair[f'hmm_hit_seq_end_{own}']))
    clan_ids = np.asarray(region_pair[f'hmm_hit_clan_acc_{own}'])
    region_clans = set(clan_ids[overlap_lens >= min_overlap])
    region_clans.add(region_pair[f'clan_acc_{own}'])  # Likely redundant.
    if not region_clans.isdisjoint(clans_in_flanks[other]):
      return False

  return True


def pair_flank_seeds(
    rng,
    region_pair,
    extra_margin = 0,
    min_overlap = 1,
):
  """Searches for a valid combination of UniProt flanks for the region pair.

  Args:
    rng: The PRNG used to randomize the order in which seeds are tried.
    region_pair: The region pair, including its sampled flank lengths and its
      precomputed flank seeds. Modified in place.
    extra_margin: Forwarded to `validate_flank_seeds`.
    min_overlap: Forwarded to `validate_flank_seeds`.

  Returns:
    `region_pair`. If a valid combination was found, the chosen seed index of
    each non-empty flank is stored under `f'{prefix}_flank_seed_idx_{suffix}'`.
    Otherwise, `region_pair` is returned without those fields.
  """
  slots = [f'{p}_{s}' for p, s in itertools.product(PREFIXES, SUFFIXES)]
  num_seeds = schemas.NUM_FLANK_SEEDS
  # One independent random ordering of the (1-based) seed indices per flank.
  seed_orders = [
      rng.sample(range(1, num_seeds + 1), k=num_seeds) for _ in slots]

  # Brute-force search over all seed combinations, in randomized order. The
  # first valid combination wins.
  for combination in itertools.product(*seed_orders):
    chosen = dict(zip(slots, combination))
    if not validate_flank_seeds(
        region_pair, chosen, extra_margin, min_overlap):
      continue
    for prefix, suffix in itertools.product(PREFIXES, SUFFIXES):
      # Flanks of length zero do not need a seed.
      if region_pair[f'{prefix}_len_{suffix}']:
        region_pair[f'{prefix}_flank_seed_idx_{suffix}'] = (
            chosen[f'{prefix}_{suffix}'])
    return region_pair

  return region_pair


def sample_synthetic_flank(rng, length):
  """Draws a string of `length` amino acids i.i.d. from `AA_PROB_TABLE`.

  Args:
    rng: The PRNG to draw from. Must provide `choices(population, weights, k)`.
    length: The number of residues to sample.

  Returns:
    A string of `length` characters, each sampled with replacement using the
    values of `AA_PROB_TABLE` as weights.
  """
  residues = [aa for aa in AA_PROB_TABLE]
  weights = [AA_PROB_TABLE[aa] for aa in residues]
  sampled = rng.choices(residues, weights, k=length)
  return ''.join(sampled)


def generate_flanks(
    rng,
    region_pair,
    flanks,
    max_len,
    extra_margin = 0,
    min_overlap = 1,
):
  """Extends a pair of Pfam domains adding N and C-terminus flanks.

  Args:
    rng: The PRNG used for all random choices.
    region_pair: The region pair to be extended. Modified in place.
    flanks: The flank generation mode, either 'synthetic' or 'uniprot'.
    max_len: The maximum length of each region plus its two flanks.
    extra_margin: Forwarded to `pair_flank_seeds` ('uniprot' mode only).
    min_overlap: Forwarded to `pair_flank_seeds` ('uniprot' mode only).

  Returns:
    `region_pair`, with the following fields added for each flank (`prefix`)
    of each sequence (`suffix`):
      + `{prefix}_len_{suffix}`: the sampled flank length.
      + `{prefix}_flank_{suffix}`: the flank residues (not set in 'uniprot'
        mode when no seed was selected for the flank).
      + `{prefix}_key_{suffix}`: a key describing the source and endpoints of
        the flank, or the empty string if the flank is empty.

  Raises:
    ValueError: If `flanks` is neither 'synthetic' nor 'uniprot'.
  """
  n_terminus = PREFIXES[0]

  # Step 1: samples the lengths of both flanks of each sequence.
  for suffix in SUFFIXES:
    sampled_lens = sample_flank_lengths(
        rng=rng,
        seq_start=region_pair[f'seq_start_{suffix}'],
        seq_end=region_pair[f'seq_end_{suffix}'],
        max_len=max_len)
    for name, value in sampled_lens.items():
      region_pair[f'{name}_{suffix}'] = value

  # Step 2: UniProt flank "seeds" must be chosen jointly for the four flanks,
  # since their validity depends on the annotations of both sequences. All the
  # remaining steps can be done independently per sequence and flank.
  if flanks == 'uniprot':
    region_pair = pair_flank_seeds(rng, region_pair, extra_margin, min_overlap)

  # Step 3: generates each of the four flanks.
  for suffix, prefix in itertools.product(SUFFIXES, PREFIXES):
    region_start = region_pair[f'seq_start_{suffix}']
    region_end = region_pair[f'seq_end_{suffix}']

    if flanks == 'synthetic':
      # Synthetic flanks are sampled i.i.d. from `AA_PROB_TABLE`,
      # independently for each flank.
      source = 'synth'
      length = region_pair[f'{prefix}_len_{suffix}']
      if prefix == n_terminus:
        # (Inclusive) endpoints of the N-terminus flank. Note that `first`
        # could be negative.
        first, last = region_start - length, region_start - 1
      else:
        # (Inclusive) endpoints of the C-terminus flank. Note that `last`
        # could be larger than the length of the original sequence.
        first, last = region_end + 1, region_end + length
      region_pair[f'{prefix}_flank_{suffix}'] = sample_synthetic_flank(
          rng=rng, length=length)
    elif flanks == 'uniprot':
      # UniProt flanks are random windows of the seeds selected by
      # `pair_flank_seeds`. Since each sequence only has a finite number of
      # precomputed seeds, it is possible that no valid seed combination
      # exists. In those cases, no flanks are added.
      seed_idx = region_pair.get(f'{prefix}_flank_seed_idx_{suffix}', None)
      if seed_idx is None:
        # No seed was selected: marks the flank as empty.
        first, last = 0, -1
      else:
        length = region_pair[f'{prefix}_len_{suffix}']
        seed_key = region_pair[f'{prefix}_flank_seed_key_{seed_idx}_{suffix}']
        seed_seq = region_pair[
            f'{prefix}_flank_seed_sequence_{seed_idx}_{suffix}']

        # Seed keys are formatted as '<accession>/<start>-<end>'.
        source, seed_span = seed_key.split('/')
        seed_first, seed_last = [int(x) for x in seed_span.split('-')]
        assert len(seed_seq) == (seed_last - seed_first + 1)

        # Picks a random window of `length` residues within the seed.
        shift = rng.randint(0, len(seed_seq) - length)
        first = seed_first + shift
        last = first + length - 1
        region_pair[f'{prefix}_flank_{suffix}'] = seed_seq[
            shift:shift + length]
    else:
      raise ValueError(
          f"flanks must be 'synthetic' or 'uniprot'. Got {flanks} instead.")

    # Sets a flank key for inspectability purposes if the flank is not empty.
    region_pair[f'{prefix}_key_{suffix}'] = (
        f'{source}/{first}-{last}' if first <= last else '')

  return region_pair


def extend_sequences(region_pair):
  """Crops each sequence to its region and attaches any generated flanks.

  Args:
    region_pair: The region pair to be processed. Modified in place.

  Returns:
    `region_pair`, where for each sequence:
      + `sequence_{s}` is replaced by the N-terminus flank, the region residues
        and the C-terminus flank, concatenated. Missing flanks are treated as
        empty.
      + `ali_start_{s}`, if present, is shifted by the length of the N-terminus
        flank.
      + `original_sequence_{s}` holds the sequence prior to the replacement.
  """
  for suffix in SUFFIXES:
    full_sequence = region_pair[f'sequence_{suffix}']
    first = region_pair[f'seq_start_{suffix}']
    last = region_pair[f'seq_end_{suffix}']
    n_flank = region_pair.get(f'n_flank_{suffix}', '')
    c_flank = region_pair.get(f'c_flank_{suffix}', '')

    # Region endpoints are 1-based and inclusive.
    region = full_sequence[first - 1:last]
    region_pair[f'sequence_{suffix}'] = ''.join((n_flank, region, c_flank))

    ali_start_field = f'ali_start_{suffix}'
    if ali_start_field in region_pair:
      region_pair[ali_start_field] += len(n_flank)

    # Keeps a copy of the sequence as it was before cropping and extending.
    region_pair[f'original_sequence_{suffix}'] = full_sequence

  return region_pair


def interval_overlaps(
    start,
    end,
    ref_starts,
    ref_ends,
):
  """Computes the overlap between closed intervals.

  Args:
    start: The (inclusive) start of the query interval.
    end: The (inclusive) end of the query interval.
    ref_starts: The (inclusive) starts of the reference intervals.
    ref_ends: The (inclusive) ends of the reference intervals.

  Returns:
    The number of positions shared by the query interval and each of the
    reference intervals, with zero for disjoint intervals.
  """
  shared_first = np.maximum(start, ref_starts)
  shared_last = np.minimum(end, ref_ends)
  # Disjoint intervals yield non-positive lengths, which are floored at zero.
  signed_lens = shared_last - shared_first + 1
  return np.maximum(0, signed_lens)


def annotate_regions(
    region_pair,
    extra_margin = 0,
    min_overlap = 1,
):
  """Finds any annotations that overlap with the regions in `region_pair`.

  Args:
    region_pair: The region pair to be annotated. Modified in place.
    extra_margin: The number of residues by which the region endpoints are
      extended when computing overlaps.
    min_overlap: The minimum overlap (in residues) for an annotation to be
      assigned to a region.

  Returns:
    `region_pair`, with the fields `overlapping_hmm_hit_{s}` and
    `overlapping_other_hit_{s}` added for each sequence. Each of them is the
    set of ids of the annotations that overlap with the region.
  """
  # Pairs each annotation type with the name of the field holding its ids.
  id_field_by_type = (('hmm_hit', 'clan_acc'), ('other_hit', 'type_id'))

  for (ann_type, id_name), suffix in itertools.product(
      id_field_by_type, SUFFIXES):
    ann_starts = np.asarray(region_pair[f'{ann_type}_seq_start_{suffix}'])
    ann_ends = np.asarray(region_pair[f'{ann_type}_seq_end_{suffix}'])
    ann_ids = np.asarray(region_pair[f'{ann_type}_{id_name}_{suffix}'])

    # Only the original Pfam region endpoints are used here; any flanks that
    # were added are not taken into account.
    region_start = region_pair[f'seq_start_{suffix}']
    region_end = region_pair[f'seq_end_{suffix}']

    overlap_lens = interval_overlaps(
        start=region_start - extra_margin,
        end=region_end + extra_margin,
        ref_starts=ann_starts,
        ref_ends=ann_ends)
    selected = overlap_lens >= min_overlap

    region_pair[f'overlapping_{ann_type}_{suffix}'] = set(ann_ids[selected])

  return region_pair


def eval_confounding_in_regions(region_pair):
  """Flags annotations shared by both regions, such as nested domains.

  Args:
    region_pair: The region pair, as annotated by `annotate_regions` and
      labelled by `add_homology_label`. Modified in place.

  Returns:
    `region_pair`, with the following boolean fields added:
      + `shares_clans`: whether both regions overlap with a common clan. For
        homologous pairs, the regions' own clans do not count.
      + `shares_{type_id}`, for each `type_id` in `OTHER_REGIONS`: whether both
        regions overlap with an annotation of that type.
      + `maybe_confounded`: whether the pair is non-homologous and shares
        annotations in any of the categories in `CONFOUNDING_REGIONS`.
  """
  label = region_pair['homology_label']

  # Clans annotated in both regions.
  common_clans = set.intersection(
      *(region_pair[f'overlapping_hmm_hit_{s}'] for s in SUFFIXES))
  # The regions' own clans are expected to be shared by homologous pairs, so
  # those are not counted.
  if label > 0:
    common_clans -= {region_pair[f'clan_acc_{s}'] for s in SUFFIXES}
  region_pair['shares_clans'] = bool(common_clans)

  # Other annotation types: shared if present in every region of the pair.
  for type_id in OTHER_REGIONS:
    present_in_each = (
        type_id in region_pair[f'overlapping_other_hit_{s}'] for s in SUFFIXES)
    region_pair[f'shares_{type_id}'] = all(present_in_each)

  # Only non-homologous pairs can be "potentially confounded" for the homology
  # task.
  is_non_homologous = label == 0
  region_pair['maybe_confounded'] = (
      is_non_homologous and
      any(region_pair[f'shares_{category}'] for category in CONFOUNDING_REGIONS))

  return region_pair


def add_bos_and_eos_flags(
    region_pair,
    flanks = None,
):
  """Marks whether the sequences lie at the start/end of a full protein seq.

  Args:
    region_pair: The region pair, as processed by `extend_sequences`. Modified
      in place.
    flanks: The flank generation mode used for the pair, if any.

  Returns:
    `region_pair`, with the fields `bos_{s}` and `eos_{s}` added for each
    sequence.
  """
  flanks_were_sampled = flanks in ('synthetic', 'uniprot')

  for suffix in SUFFIXES:
    full_len = len(region_pair[f'original_sequence_{suffix}'])
    starts_protein = region_pair[f'seq_start_{suffix}'] == 1
    ends_protein = region_pair[f'seq_end_{suffix}'] == full_len

    # With synthetic or uniprot flanks, the start / end of the full sequence is
    # only biologically relevant if the sampled flank length is zero.
    if flanks_were_sampled:
      starts_protein &= region_pair[f'n_len_{suffix}'] == 0
      ends_protein &= region_pair[f'c_len_{suffix}'] == 0

    region_pair[f'bos_{suffix}'] = starts_protein
    region_pair[f'eos_{suffix}'] = ends_protein

  return region_pair


def add_extended_region_keys(region_pair):
  """Adds keys summarizing each region together with its flanks.

  Args:
    region_pair: The region pair to be processed. Modified in place.

  Returns:
    `region_pair`, with the field `extended_key_{s}` added for each sequence.
    It consists of the N-terminus flank key, the region key and the C-terminus
    flank key, separated by semicolons. Missing flank keys are left empty.
  """
  for suffix in SUFFIXES:
    key_parts = (
        region_pair.get(f'n_key_{suffix}', ''),
        region_pair[f'key_{suffix}'],
        region_pair.get(f'c_key_{suffix}', ''),
    )
    region_pair[f'extended_key_{suffix}'] = ';'.join(key_parts)
  return region_pair


def compute_alignment_path(region_pair):
  """Derives the alignment path from the gapped sequences in `region_pair`.

  Args:
    region_pair: The region pair, containing the fields `gapped_sequence_x` and
      `gapped_sequence_y`. Modified in place.

  Returns:
    `region_pair`, with the outputs of
    `alignment.alignment_from_gapped_sequences` stored in the fields `matches`,
    `ali_start_x` and `ali_start_y`.
  """
  gapped = {
      f'gapped_sequence_{s}': region_pair[f'gapped_sequence_{s}']
      for s in ('x', 'y')
  }
  matches, ali_start = alignment.alignment_from_gapped_sequences(**gapped)
  region_pair.update(
      matches=matches,
      ali_start_x=ali_start[0],
      ali_start_y=ali_start[1])
  return region_pair


def subsample_region_pairs(
    region_pair,
    count,
    min_count,
    resample_ratio,
    global_seed,
):
  """Randomly discards region pairs to rebalance label distribution.

  Args:
    region_pair: The candidate region pair.
    count: The number of region pairs in the category of `region_pair`.
    min_count: The number of region pairs in the reference category.
    resample_ratio: The desired ratio between the (expected) number of pairs
      kept in this category and `min_count`.
    global_seed: A global seed for the PRNG.

  Yields:
    `region_pair`, with probability `resample_ratio * min_count / count`.
  """
  prng = get_prng(region_pair, global_seed)
  keep_prob = resample_ratio * (min_count / count)
  draw = prng.uniform(0, 1)
  if draw <= keep_prob:
    yield region_pair


def add_homology_label(region):
  """Computes (ternary) homology label (non-homologs / same clan / same fam).

  Args:
    region: The region pair to be labelled. Modified in place.

  Returns:
    `region`, with the field `homology_label` set to the number of accession
    levels (`pfam_acc`, `clan_acc`) at which both regions agree.
  """
  first, second = SUFFIXES[0], SUFFIXES[1]
  # Note: regions with the same Pfam accession (`pfam_acc`) are guaranteed to
  # also have the same Clan accession (`clan_acc`), but not the other way
  # around. Hence, the label is 2 for same family, 1 for same clan only and 0
  # otherwise.
  region['homology_label'] = sum(
      int(region[f'{level}_{first}'] == region[f'{level}_{second}'])
      for level in ('pfam_acc', 'clan_acc'))
  return region


def process_region_pair_alignment(
    region_pair,
    flanks = None,
    max_len = 511,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """Generates a sample for the alignment task from a pair of Pfam regions.

  Args:
    region_pair: A pair of regions from the same Pfam family, including their
      gapped sequences. Modified in place.
    flanks: The flank generation mode ('synthetic' or 'uniprot'). Any other
      value results in no flanks being added.
    max_len: The maximum length of each region plus its two flanks.
    extra_margin: Forwarded to `generate_flanks`.
    min_overlap: Forwarded to `generate_flanks`.
    global_seed: A global seed for the PRNG.

  Returns:
    `region_pair`, extended with the ground-truth alignment (`matches`,
    `ali_start_{s}`, `states`), the sequences to be aligned, the `bos_{s}` /
    `eos_{s}` flags, the percent identity and the extended keys.
  """
  rng = get_prng(region_pair, global_seed=global_seed)

  # Ground-truth alignment path (starting positions and matches), parsed from
  # the gapped sequences.
  region_pair = compute_alignment_path(region_pair)

  # Optionally, adds synthetic or UniProt flanks to both regions.
  if flanks in ('synthetic', 'uniprot'):
    region_pair = generate_flanks(
        rng=rng,
        region_pair=region_pair,
        flanks=flanks,
        max_len=max_len,
        extra_margin=extra_margin,
        min_overlap=min_overlap)

  # Crops the sequences to the regions of interest, plus flanks (if any).
  region_pair = extend_sequences(region_pair)

  # These two flags are always unset for the alignment task.
  region_pair.update(maybe_confounded=False, fallback=False)

  # Flags whether the new endpoints coincide with the start (resp. end) of a
  # full sequence.
  region_pair = add_bos_and_eos_flags(region_pair, flanks)

  # CIGAR-like compressed state string of the alignment.
  uncompressed_states = alignment.states_from_matches(region_pair['matches'])
  region_pair['states'] = alignment.compress_states(uncompressed_states)

  # Percent identity of the sequence pair under the ground-truth alignment,
  # taking any modifications to the flanks into account.
  pid = alignment.pid_from_matches(
      sequence_x=region_pair['sequence_x'],
      sequence_y=region_pair['sequence_y'],
      matches=region_pair['matches'],
      ali_start_x=region_pair['ali_start_x'],
      ali_start_y=region_pair['ali_start_y'])
  region_pair['percent_identity'] = pid

  # New keys for the sequences, summarizing the new endpoints.
  return add_extended_region_keys(region_pair)


def process_region_pair_homology(
    region_pair,
    flanks = None,
    max_len = 511,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """Generates a sample for the homology task from a pair of Pfam regions.

  Args:
    region_pair: A pair of regions with its `homology_label` already set.
      Modified in place.
    flanks: The flank generation mode ('synthetic' or 'uniprot'). Any other
      value results in no flanks being added.
    max_len: The maximum length of each region plus its two flanks.
    extra_margin: Extends region boundaries by `extra_margin` residues when
      evaluating overlap between annotations.
    min_overlap: The minimum overlap (in residues) for an annotation to be
      assigned to a region.
    global_seed: A global seed for the PRNG.

  Returns:
    `region_pair`, extended with the sequences to be compared, the shared
    annotation flags, the `bos_{s}` / `eos_{s}` flags, the percent identity
    (NaN unless both regions are in the same Pfam family) and the extended
    keys.
  """
  rng = get_prng(region_pair, global_seed=global_seed)

  # Ground-truth alignments (and, thus, percent identities) are only available
  # at the highest level of homology, namely, when both regions belong to the
  # same Pfam family.
  same_family = region_pair['homology_label'] == 2
  if same_family:
    region_pair = compute_alignment_path(region_pair)

  # Optionally, adds synthetic ('synthetic') or UniProt ('uniprot') flanks to
  # both regions.
  if flanks in ('synthetic', 'uniprot'):
    region_pair = generate_flanks(
        rng=rng,
        region_pair=region_pair,
        flanks=flanks,
        max_len=max_len,
        extra_margin=extra_margin,
        min_overlap=min_overlap)

  # Crops the sequences to the regions of interest, plus flanks (if any).
  region_pair = extend_sequences(region_pair)

  # Best-effort detection of shared annotations that could act as confounding
  # factors. Annotation databases are incomplete, so residual, undetected
  # "confounding" might persist for some region pairs.
  region_pair = eval_confounding_in_regions(
      annotate_regions(region_pair, extra_margin, min_overlap))

  # Flags whether the new endpoints coincide with the start (resp. end) of a
  # full sequence.
  region_pair = add_bos_and_eos_flags(region_pair, flanks)

  if same_family:
    # Percent identity of the sequence pair under the ground-truth alignment,
    # taking any modifications to the flanks into account.
    pid = alignment.pid_from_matches(
        sequence_x=region_pair['sequence_x'],
        sequence_y=region_pair['sequence_y'],
        matches=region_pair['matches'],
        ali_start_x=region_pair['ali_start_x'],
        ali_start_y=region_pair['ali_start_y'])
  else:
    pid = float('nan')
  region_pair['percent_identity'] = pid

  # New keys for the sequences, summarizing the new endpoints.
  return add_extended_region_keys(region_pair)


def build_pfam_alignments_pipeline(
    file_pattern,
    dataset_splits_path,
    target_split,
    output_path,
    max_len = 511,
    flanks = None,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """3a) Returns a pipeline to generate samples for sequence alignment task.

  Args:
    file_pattern: The file pattern from which to read preprocessed Pfam shards.
      This is assumed to be the result of steps 1a), 1b) and, optionally, step
      2) of the full preprocessing pipeline.
      See `preprocess_tables_lib.py` and `uniprot_flanks_lib.py` for additional
      details.
    dataset_splits_path: The path to the key, split mapping file.
    target_split: The dataset split for which to generate pairwise alignment
      data.
    output_path: The path prefix to the output files.
    max_len: The maximum length of sequences to be included in the output
      dataset (without BOS or EOS tokens).
    flanks: The approach to be used to add flanking sequences to Pfam regions.
      If `None`, no flanking sequences will be added. Supported modes include
      `synthetic` and `uniprot`.
    extra_margin: Extends sequence boundaries by `extra_margin` residues when
      evaluating overlap between annotations.
    min_overlap: The minimum number of residues in a sequence that need to
      overlap with a region annotation in order for the annotation to be applied
      to the sequence.
    global_seed: A global seed for the PRNG.

  Returns:
    A function that takes the root of a Beam pipeline and attaches to it all
    the transforms needed to generate and write the dataset. The function
    itself returns nothing.
  """
  def pipeline(root):
    # Reads data preprocessed by steps 1a) and 1b) in `preprocess_tables_lib.py`
    # and, optionally, step 2) in `uniprot_flanks_lib.py`.
    with_flank_seeds = flanks == 'uniprot'
    # The flank seed fields are only needed (and read) in 'uniprot' mode.
    fields_to_keep = ALIGNMENT_FIELDS
    if with_flank_seeds:
      fields_to_keep += tuple(f[0] for f in schemas.FLANK_FIELDS)
    regions = (
        root
        | 'ReadParsedPfamData' >> ReadParsedPfamData(
            file_pattern=file_pattern,
            dataset_splits_path=dataset_splits_path,
            fields_to_keep=fields_to_keep,
            with_flank_seeds=with_flank_seeds,
            max_len=max_len,
            filter_by_qc=True))
    # Filters sequences not belonging to the target split and removes the no
    # longer needed split field.
    filtered_regions = (
        regions
        | 'FilterBySplit' >> beam.Filter(
            lambda x: x['split'] == target_split)
        | 'DropSplitField' >> beam.Map(
            functools.partial(
                utils.drop_record_fields,
                fields_to_drop=['split'])))
    # Enumerates all pairs of regions sharing the same Pfam accession. Each
    # region pair is processed to produce the final fields that will be used for
    # training and evaluating the models on the pairwise alignment task.
    region_pairs = (
        filtered_regions
        | 'EnumerateAllFamilyPairs' >> utils.Combinations(
            groupby_field='pfam_acc',
            key_field='key',
            num_samples=None,
            suffixes=SUFFIXES)
        | 'ProcessRegionPairs' >> beam.Map(
            functools.partial(
                process_region_pair_alignment,
                flanks=flanks,
                max_len=max_len,
                extra_margin=extra_margin,
                min_overlap=min_overlap,
                global_seed=global_seed)))
    # Writes postprocessed region pairs to disk as tab-delimited sharded text
    # files.
    _ = (
        region_pairs
        | 'WriteToTable' >> schemas_lib.WriteToTable(
            file_path_prefix=output_path,
            schema_cls=schemas.PairwiseAlignmentRow))

  return pipeline


def build_pfam_homology_pipeline(
    file_pattern,
    dataset_splits_path,
    target_split,
    output_path,
    avg_num_samples,
    prob_pos_different_family = 0.11,
    prob_neg = 0.5,
    max_len = 511,
    flanks = None,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """3b) Returns a pipeline to generate samples for homology detection task.

  Args:
    file_pattern: The file pattern from which to read preprocessed Pfam shards.
      This is assumed to be the result of steps 1a), 1b) and, optionally, step
      2) of the full preprocessing pipeline.
      See `preprocess_tables_lib.py` and `uniprot_flanks_lib.py` for additional
      details.
    dataset_splits_path: The path to the key, split mapping file.
    target_split: The dataset split for which to generate pairwise homology
      data.
    output_path: The path prefix to the output files.
    avg_num_samples: The (expected) number of samples (sequence pairs) to
      subsample (homologous and non-homologous).
    prob_pos_different_family: The (expected) proportion of samples consisting
      of region pairs in the same clan but different families.
    prob_neg: The (expected) proportion of samples consisting of non-homologous
      region pairs, that is, regions in different clans.
    max_len: The maximum length of sequences to be included in the output
      dataset.
    flanks: The approach to be used to add flanking sequences to Pfam
      regions. If `None`, no flanking sequences will be added. Supported modes
      include `synthetic` and `uniprot`.
    extra_margin: Extends sequence boundaries by `extra_margin` residues when
      evaluating overlap between annotations.
    min_overlap: The minimum number of residues in a sequence that need to
      overlap with a region annotation in order for the annotation to be applied
      to the sequence.
    global_seed: A global seed for the PRNG.

  Returns:
    A function that takes the root of a Beam pipeline and attaches to it all
    the transforms needed to generate and write the dataset. The function
    itself returns nothing.
  """
  def pipeline(root):
    # Reads data preprocessed by steps 1a) and 1b) in `preprocess_tables_lib.py`
    # and, optionally, step 2) in `uniprot_flanks_lib.py`.
    with_flank_seeds = flanks == 'uniprot'
    # The flank seed fields are only needed (and read) in 'uniprot' mode.
    fields_to_keep = ALIGNMENT_FIELDS
    if with_flank_seeds:
      fields_to_keep += tuple(f[0] for f in schemas.FLANK_FIELDS)
    regions = (
        root
        | 'ReadParsedPfamData' >> ReadParsedPfamData(
            file_pattern=file_pattern,
            dataset_splits_path=dataset_splits_path,
            fields_to_keep=fields_to_keep,
            with_flank_seeds=with_flank_seeds,
            max_len=max_len,
            filter_by_qc=True))
    # Filters sequences not belonging to the target split and removes the no
    # longer needed split field.
    filtered_regions = (
        regions
        | 'FilterBySplit' >> beam.Filter(
            lambda x: x['split'] == target_split)
        | 'DropSplitField' >> beam.Map(
            functools.partial(
                utils.drop_record_fields,
                fields_to_drop=['split'])))

    # Enumerates a subsample of (on average) `avg_num_samples` pairs of
    # homologous and non-homologous regions, keeping only the latter and adding
    # a class label for each pair:
    # + neg: homology_label = 0, if non-homologs (different clan).
    # + mid: homology_label = 1, if remote homologs (same clan, different fams).
    # + pos: homology_label = 2, if homologs (same family).
    # This subsample will be heavily biased towards negative samples (neg).
    neg_region_pairs = (
        filtered_regions
        | 'EnumerateNegRegionPairs' >> utils.SubsampleOuterProduct(
            avg_num_samples=avg_num_samples,
            groupby_field=None,
            key_field='key',
            suffixes=SUFFIXES)
        | 'KeepNegRegionPairs' >> beam.Filter(
            lambda x: x['clan_acc_x'] != x['clan_acc_y'])
        | 'AddHomologyLabelNegRegionPairs' >> beam.Map(add_homology_label))

    # Enumerates a subsample of (on average) `avg_num_samples` pairs of
    # homologous regions (either in the same family, or in the same clan but
    # in different families), keeping only the latter and adding a class label
    # for each pair.
    # This subsample will be heavily biased towards samples in the same clan but
    # in different families (mid).
    mid_region_pairs = (
        filtered_regions
        | 'EnumerateMidRegionPairs' >> utils.SubsampleOuterProduct(
            avg_num_samples=avg_num_samples,
            groupby_field='clan_acc',
            key_field='key',
            suffixes=SUFFIXES)
        | 'KeepMidRegionPairs' >> beam.Filter(
            lambda x: x['pfam_acc_x'] != x['pfam_acc_y'])
        | 'AddHomologyLabelMidRegionPairs' >> beam.Map(add_homology_label))

    # Enumerates all pairs of regions sharing the same Pfam accession, adding
    # the corresponding homology label (2) to all of the resulting records.
    pos_region_pairs = (
        filtered_regions
        | 'EnumeratePosRegionPairs' >> utils.Combinations(
            groupby_field='pfam_acc',
            key_field='key',
            num_samples=None,
            suffixes=SUFFIXES)
        | 'AddHomologyLabelPosRegionPairs' >> beam.Map(add_homology_label))

    # Computes the number of region pairs in each category. We assume that
    # homologs (same family), i.e., `count_pos`, is the smallest. That holds
    # true for Pfam-A seed 34.0.
    count_pos = pos_region_pairs | 'CountPos' >> beam.combiners.Count.Globally()
    count_mid = mid_region_pairs | 'CountMid' >> beam.combiners.Count.Globally()
    count_neg = neg_region_pairs | 'CountNeg' >> beam.combiners.Count.Globally()

    # Downsamples each category to obtain data with the desired class label
    # distribution. The same-family pairs (pos) are all kept and act as the
    # reference category: the other two categories are downsampled relative to
    # `count_pos`.
    prob_pos_same_family = 1.0 - prob_pos_different_family - prob_neg
    # NOTE(review): this check is skipped when Python runs with `-O`.
    assert prob_pos_same_family > 0

    # The counts are passed as side inputs, filling in the `count` and
    # `min_count` arguments of `subsample_region_pairs`, respectively.
    mid_region_pairs = (
        mid_region_pairs
        | 'DownsampleMidRegionPairs' >> beam.FlatMap(
            subsample_region_pairs,
            beam.pvalue.AsSingleton(count_mid),
            beam.pvalue.AsSingleton(count_pos),
            resample_ratio=prob_pos_different_family / prob_pos_same_family,
            global_seed=global_seed))
    neg_region_pairs = (
        neg_region_pairs
        | 'DownsampleNegRegionPairs' >> beam.FlatMap(
            subsample_region_pairs,
            beam.pvalue.AsSingleton(count_neg),
            beam.pvalue.AsSingleton(count_pos),
            resample_ratio=prob_neg / prob_pos_same_family,
            global_seed=global_seed))

    # Homologous and non-homologous regions are merged. The regions pairs are
    # then processed to produce the final fields that will be used for training
    # and evaluating the models on the pairwise homology detection task.
    region_pairs = (
        (pos_region_pairs, mid_region_pairs, neg_region_pairs)
        | 'MergeAllClasses' >> beam.Flatten()
        | 'ReshuffleAfterMerging' >> beam.Reshuffle()
        | 'ProcessRegionPairs' >> beam.Map(
            functools.partial(
                process_region_pair_homology,
                flanks=flanks,
                max_len=max_len,
                extra_margin=extra_margin,
                min_overlap=min_overlap,
                global_seed=global_seed)))
    # Writes postprocessed region pairs to disk as tab-delimited sharded text
    # files.
    _ = (
        region_pairs
        | 'WriteToTable' >> schemas_lib.WriteToTable(
            file_path_prefix=output_path,
            schema_cls=schemas.PairwiseHomologyRow))

  return pipeline
